
package com.jstarcraft.ai.jsat.linear.vectorcollection;

import static com.jstarcraft.ai.jsat.utils.SystemInfo.LogicalCores;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.Function;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;

import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.linear.VecPaired;
import com.jstarcraft.ai.jsat.math.OnLineStatistics;
import com.jstarcraft.ai.jsat.utils.ListUtils;
import com.jstarcraft.ai.jsat.utils.concurrent.ParallelUtils;

/**
 * A collection of common utility methods to perform on a
 * {@link VectorCollection}: running a k-nearest-neighbor or radius query for
 * every vector in a list, and summarizing the distance to each query's k'th
 * nearest neighbor.
 * <p>
 * The {@link ExecutorService}-based methods log (rather than throw) failures of
 * their worker tasks. In that case the value they return may be incomplete. If
 * the calling thread is interrupted while waiting, its interrupt status is
 * restored before returning.
 * 
 * @author Edward Raff
 */
public class VectorCollectionUtils {
    /**
     * Searches the given collection for the {@code k} nearest neighbors for every
     * data point in the given search list.
     * 
     * @param            <V0> the vector type in the collection
     * @param            <V1> the type of vector in the search collection
     * @param collection the collection to search from
     * @param search     the vectors to search for
     * @param k          the number of nearest neighbors
     * @return The list of lists for all nearest neighbors; the i'th entry is the
     *         result for the i'th element of {@code search}
     */
    public static <V0 extends Vec, V1 extends Vec> List<List<? extends VecPaired<V0, Double>>> allNearestNeighbors(VectorCollection<V0> collection, List<V1> search, int k) {
        List<List<? extends VecPaired<V0, Double>>> results = new ArrayList<>(search.size());
        for (Vec v : search) {
            results.add(collection.search(v, k));
        }
        return results;
    }

    /**
     * Searches the given collection for the {@code k} nearest neighbors for every
     * data point in the given search array.
     * 
     * @param            <V0> the vector type in the collection
     * @param            <V1> the type of vector in the search array
     * @param collection the collection to search from
     * @param search     the vectors to search for
     * @param k          the number of nearest neighbors
     * @return The list of lists for all nearest neighbors; the i'th entry is the
     *         result for the i'th element of {@code search}
     */
    public static <V0 extends Vec, V1 extends Vec> List<List<? extends VecPaired<V0, Double>>> allNearestNeighbors(VectorCollection<V0> collection, V1[] search, int k) {
        return allNearestNeighbors(collection, Arrays.asList(search), k);
    }

    /**
     * Searches the given collection for the {@code k} nearest neighbors for every
     * data point in the given search list, using a thread pool.
     * 
     * @param            <V0> the vector type in the collection
     * @param            <V1> the type of vector in the search collection
     * @param collection the collection to search from
     * @param search     the vectors to search for
     * @param k          the number of nearest neighbors
     * @param threadpool the source of threads to perform the computation in
     *                   parallel
     * @return The list of lists for all nearest neighbors, in the same order as
     *         {@code search}. If a worker task fails or the calling thread is
     *         interrupted, the failure is logged and the list may be incomplete.
     * @deprecated This will be deleted soon; use
     *             {@link #allNearestNeighbors(VectorCollection, List, int, boolean)}
     *             instead.
     */
    @Deprecated
    public static <V0 extends Vec, V1 extends Vec> List<List<? extends VecPaired<V0, Double>>> allNearestNeighbors(final VectorCollection<V0> collection, List<V1> search, final int k, ExecutorService threadpool) {
        return searchAll(search, v -> collection.search(v, k), threadpool);
    }

    /**
     * Searches the given collection for the {@code k} nearest neighbors for every
     * data point in the given search list.
     * 
     * @param            <V0> the vector type in the collection
     * @param            <V1> the type of vector in the search collection
     * @param collection the collection to search from
     * @param search     the vectors to search for
     * @param k          the number of nearest neighbors
     * @param parallel   {@code true} if multiple threads should be used to perform
     *                   the searches. {@code false} if they should be done in a
     *                   single threaded manner.
     * @return The list of lists for all nearest neighbors; the i'th entry is the
     *         result for the i'th element of {@code search}
     */
    public static <V0 extends Vec, V1 extends Vec> List<List<? extends VecPaired<V0, Double>>> allNearestNeighbors(final VectorCollection<V0> collection, List<V1> search, final int k, boolean parallel) {
        // Collectors.toList() keeps encounter order, even for a parallel stream
        return ParallelUtils.streamP(search.stream(), parallel).map(v -> collection.search(v, k)).collect(Collectors.toList());
    }

    /**
     * Searches the given collection for all the neighbors within a distance of
     * {@code radius} for every data point in the given search list, using a
     * thread pool.
     * 
     * @param            <V0> the vector type in the collection
     * @param            <V1> the type of vector in the search collection
     * @param collection the collection to search from
     * @param search     the vectors to search for
     * @param radius     the distance to search for neighbors
     * @param threadpool the source of threads to perform the computation in
     *                   parallel
     * @return The list of lists of neighbors within {@code radius}, in the same
     *         order as {@code search}. If a worker task fails or the calling
     *         thread is interrupted, the failure is logged and the list may be
     *         incomplete.
     */
    public static <V0 extends Vec, V1 extends Vec> List<List<? extends VecPaired<V0, Double>>> allEpsNeighbors(final VectorCollection<V0> collection, List<V1> search, final double radius, ExecutorService threadpool) {
        return searchAll(search, v -> collection.search(v, radius), threadpool);
    }

    /**
     * Searches the given collection for all the neighbors within a distance of
     * {@code radius} for every data point in the given search list.
     * 
     * @param            <V0> the vector type in the collection
     * @param            <V1> the type of vector in the search collection
     * @param collection the collection to search from
     * @param search     the vectors to search for
     * @param radius     the distance to search for neighbors
     * @param parallel   {@code true} if multiple threads should be used to perform
     *                   the searches. {@code false} if they should be done in a
     *                   single threaded manner.
     * @return The list of lists of neighbors within {@code radius}; the i'th
     *         entry is the result for the i'th element of {@code search}
     */
    public static <V0 extends Vec, V1 extends Vec> List<List<? extends VecPaired<V0, Double>>> allEpsNeighbors(final VectorCollection<V0> collection, List<V1> search, final double radius, boolean parallel) {
        // Collectors.toList() keeps encounter order, even for a parallel stream
        return ParallelUtils.streamP(search.stream(), parallel).map(v -> collection.search(v, radius)).collect(Collectors.toList());
    }

    /**
     * Searches the given collection for the {@code k} nearest neighbors for every
     * data point in the given search array, using a thread pool.
     * 
     * @param            <V0> the vector type in the collection
     * @param            <V1> the type of vector in the search collection
     * @param collection the collection to search from
     * @param search     the vectors to search for
     * @param k          the number of nearest neighbors
     * @param threadpool the source of threads to perform the computation in
     *                   parallel
     * @return The list of lists for all nearest neighbors, in the same order as
     *         {@code search}. If a worker task fails or the calling thread is
     *         interrupted, the failure is logged and the list may be incomplete.
     */
    public static <V0 extends Vec, V1 extends Vec> List<List<? extends VecPaired<V0, Double>>> allNearestNeighbors(final VectorCollection<V0> collection, V1[] search, final int k, ExecutorService threadpool) {
        return searchAll(Arrays.asList(search), v -> collection.search(v, k), threadpool);
    }

    /**
     * Computes statistics about the distance of the k'th nearest neighbor for each
     * data point in the {@code search} list.
     * 
     * @param            <V0> the type of vector in the collection
     * @param            <V1> the type of vector in the search collection
     * @param collection the collection of vectors to query from
     * @param search     the list of vectors to search for
     * @param k          the nearest neighbor to use (1-based)
     * @return the statistics for the distance of the k'th nearest neighbor from the
     *         query point
     * @throws IndexOutOfBoundsException if {@code k < 1}, or if a search returns
     *                                   fewer than {@code k} neighbors
     */
    public static <V0 extends Vec, V1 extends Vec> OnLineStatistics getKthNeighborStats(VectorCollection<V0> collection, List<V1> search, int k) {
        OnLineStatistics stats = new OnLineStatistics();
        for (Vec v : search) {
            // the neighbor list is indexed from 0, so the k'th neighbor is at k - 1
            stats.add(collection.search(v, k).get(k - 1).getPair());
        }
        return stats;
    }

    /**
     * Computes statistics about the distance of the k'th nearest neighbor for each
     * data point in the {@code search} array.
     * 
     * @param            <V0> the type of vector in the collection
     * @param            <V1> the type of vector in the search array
     * @param collection the collection of vectors to query from
     * @param search     the array of vectors to search for
     * @param k          the nearest neighbor to use (1-based)
     * @return the statistics for the distance of the k'th nearest neighbor from the
     *         query point
     * @throws IndexOutOfBoundsException if {@code k < 1}, or if a search returns
     *                                   fewer than {@code k} neighbors
     */
    public static <V0 extends Vec, V1 extends Vec> OnLineStatistics getKthNeighborStats(VectorCollection<V0> collection, V1[] search, int k) {
        return getKthNeighborStats(collection, Arrays.asList(search), k);
    }

    /**
     * Computes statistics about the distance of the k'th nearest neighbor for each
     * data point in the {@code search} list, using a thread pool.
     * 
     * @param            <V0> the type of vector in the collection
     * @param            <V1> the type of vector in the search collection
     * @param collection the collection of vectors to query from
     * @param search     the list of vectors to search for
     * @param k          the nearest neighbor to use (1-based)
     * @param threadpool the source of threads to perform the computation in
     *                   parallel
     * @return the statistics for the distance of the k'th nearest neighbor from the
     *         query point. If a worker task fails or the calling thread is
     *         interrupted, the failure is logged and the statistics may cover
     *         only part of {@code search}.
     */
    public static <V0 extends Vec, V1 extends Vec> OnLineStatistics getKthNeighborStats(final VectorCollection<V0> collection, List<V1> search, final int k, ExecutorService threadpool) {
        List<Future<OnLineStatistics>> futureStats = new ArrayList<>(LogicalCores);

        for (final List<V1> subSearch : ListUtils.splitList(search, LogicalCores)) {
            // each chunk is summarized independently and merged below
            Callable<OnLineStatistics> task = () -> getKthNeighborStats(collection, subSearch, k);
            futureStats.add(threadpool.submit(task));
        }

        OnLineStatistics stats = new OnLineStatistics();
        try {
            for (OnLineStatistics subResult : ListUtils.collectFutures(futureStats)) {
                stats = OnLineStatistics.add(stats, subResult);
            }
        } catch (ExecutionException | InterruptedException ex) {
            logFailure(ex);
        }
        return stats;
    }

    /**
     * Computes statistics about the distance of the k'th nearest neighbor for each
     * data point in the {@code search} array, using a thread pool.
     * 
     * @param            <V0> the type of vector in the collection
     * @param            <V1> the type of vector in the search array
     * @param collection the collection of vectors to query from
     * @param search     the array of vectors to search for
     * @param k          the nearest neighbor to use (1-based)
     * @param threadpool the source of threads to perform the computation in
     *                   parallel
     * @return the statistics for the distance of the k'th nearest neighbor from the
     *         query point
     */
    public static <V0 extends Vec, V1 extends Vec> OnLineStatistics getKthNeighborStats(final VectorCollection<V0> collection, V1[] search, final int k, ExecutorService threadpool) {
        return getKthNeighborStats(collection, Arrays.asList(search), k, threadpool);
    }

    /**
     * Splits {@code search} into one chunk per logical core, runs {@code query}
     * on every vector of each chunk in {@code threadpool}, and concatenates the
     * per-chunk answers in the original order of {@code search}.
     * 
     * @param            <V0> the vector type in the collection being queried
     * @param            <V1> the type of vector in the search collection
     * @param search     the vectors to search for
     * @param query      the search to perform for a single vector
     * @param threadpool the source of threads to perform the computation in
     *                   parallel
     * @return the per-vector results; possibly incomplete if a task failed or the
     *         calling thread was interrupted (the failure is logged)
     */
    private static <V0 extends Vec, V1 extends Vec> List<List<? extends VecPaired<V0, Double>>> searchAll(List<V1> search, final Function<Vec, List<? extends VecPaired<V0, Double>>> query, ExecutorService threadpool) {
        List<List<? extends VecPaired<V0, Double>>> results = new ArrayList<>(search.size());
        List<Future<List<List<? extends VecPaired<V0, Double>>>>> subResults = new ArrayList<>(LogicalCores);

        for (final List<V1> subSearch : ListUtils.splitList(search, LogicalCores)) {
            subResults.add(threadpool.submit(() -> {
                List<List<? extends VecPaired<V0, Double>>> subResult = new ArrayList<>(subSearch.size());
                for (Vec v : subSearch) {
                    subResult.add(query.apply(v));
                }
                return subResult;
            }));
        }

        try {
            // futures are collected in submission order, which preserves the input order
            for (List<List<? extends VecPaired<V0, Double>>> subResult : ListUtils.collectFutures(subResults)) {
                results.addAll(subResult);
            }
        } catch (ExecutionException | InterruptedException ex) {
            logFailure(ex);
        }
        return results;
    }

    /**
     * Logs a failure of thread-pool work at {@link Level#SEVERE}. If the failure
     * is an {@link InterruptedException}, the current thread's interrupt status is
     * restored first so that callers can still observe the interruption.
     * 
     * @param ex the exception raised while waiting on the worker tasks
     */
    private static void logFailure(Exception ex) {
        if (ex instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
        Logger.getLogger(VectorCollectionUtils.class.getName()).log(Level.SEVERE, null, ex);
    }
}
